#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################
# Some parts of the code are borrowed from: https://github.com/pytorch/vision
# with the following license:
#
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################
# converted using the pretrained model given by: https://github.com/shicai/MobileNet-Caffe
# BSD 3-Clause License
#
# Copyright (c) 2017-, Shicai Yang
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################


import torch
import torch.nn.functional as F
from ... import xnn

__all__ = ['MobileNetV2Shicai','mobilenetv2_shicai']


###################################################
def get_config():
    """Return the default (empty) configuration node for this model."""
    return xnn.utils.ConfigNode()


###################################################
def load_weights(model, weight_file):
    """Load pretrained weights into *model* from *weight_file*.

    Args:
        model: the torch module to populate.
        weight_file: path to a checkpoint file; when None, loading is skipped.

    Returns:
        The model with weights loaded, or None when weight_file is None
        (preserves the original early-return behavior).
    """
    if weight_file is None:
        return None
    #
    # xnn.utils.load_weights handles checkpoint format/key-mapping details
    # that a plain torch.load + load_state_dict would not.
    model = xnn.utils.load_weights(model, weight_file)
    return model


class MobileNetV2Shicai(torch.nn.Module):
    """MobileNetV2 classification network converted from shicai's Caffe model.

    The attribute names (e.g. ``conv2_1_expand``) mirror the original Caffe
    layer names so that converted pretrained weights map onto this module's
    state_dict; do not rename them.  The network is fully convolutional: the
    classifier ``fc7`` is a 1x1 convolution applied after global average
    pooling, followed by a flatten to (N, 1000).
    """
    def __init__(self, model_config):
        # model_config is currently unused beyond construction; kept for API
        # consistency with the other models in this package.
        super().__init__()

        # Stem + inverted-residual bottlenecks. Each bottleneck is three
        # convs: 1x1 "expand", 3x3 depthwise ("dwise", groups == channels),
        # and 1x1 "linear" projection (no ReLU after the projection).
        # eps/momentum values come from the converted Caffe batch-norm params.
        self.conv1 = self.__conv(2, name='conv1', in_channels=3, out_channels=32, kernel_size=(3, 3), stride=(2, 2), groups=1, bias=False)
        self.conv1_bn = self.__batch_normalization(2, 'conv1/bn', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_1_expand = self.__conv(2, name='conv2_1/expand', in_channels=32, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_1_expand_bn = self.__batch_normalization(2, 'conv2_1/expand/bn', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_1_dwise = self.__conv(2, name='conv2_1/dwise', in_channels=32, out_channels=32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        self.conv2_1_dwise_bn = self.__batch_normalization(2, 'conv2_1/dwise/bn', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_1_linear = self.__conv(2, name='conv2_1/linear', in_channels=32, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_1_linear_bn = self.__batch_normalization(2, 'conv2_1/linear/bn', num_features=16, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_2_expand = self.__conv(2, name='conv2_2/expand', in_channels=16, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_2_expand_bn = self.__batch_normalization(2, 'conv2_2/expand/bn', num_features=96, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_2_dwise = self.__conv(2, name='conv2_2/dwise', in_channels=96, out_channels=96, kernel_size=(3, 3), stride=(2, 2), groups=96, bias=False)
        self.conv2_2_dwise_bn = self.__batch_normalization(2, 'conv2_2/dwise/bn', num_features=96, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_2_linear = self.__conv(2, name='conv2_2/linear', in_channels=96, out_channels=24, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_2_linear_bn = self.__batch_normalization(2, 'conv2_2/linear/bn', num_features=24, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_1_expand = self.__conv(2, name='conv3_1/expand', in_channels=24, out_channels=144, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_1_expand_bn = self.__batch_normalization(2, 'conv3_1/expand/bn', num_features=144, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_1_dwise = self.__conv(2, name='conv3_1/dwise', in_channels=144, out_channels=144, kernel_size=(3, 3), stride=(1, 1), groups=144, bias=False)
        self.conv3_1_dwise_bn = self.__batch_normalization(2, 'conv3_1/dwise/bn', num_features=144, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_1_linear = self.__conv(2, name='conv3_1/linear', in_channels=144, out_channels=24, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_1_linear_bn = self.__batch_normalization(2, 'conv3_1/linear/bn', num_features=24, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_2_expand = self.__conv(2, name='conv3_2/expand', in_channels=24, out_channels=144, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_2_expand_bn = self.__batch_normalization(2, 'conv3_2/expand/bn', num_features=144, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_2_dwise = self.__conv(2, name='conv3_2/dwise', in_channels=144, out_channels=144, kernel_size=(3, 3), stride=(2, 2), groups=144, bias=False)
        self.conv3_2_dwise_bn = self.__batch_normalization(2, 'conv3_2/dwise/bn', num_features=144, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_2_linear = self.__conv(2, name='conv3_2/linear', in_channels=144, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_2_linear_bn = self.__batch_normalization(2, 'conv3_2/linear/bn', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_1_expand = self.__conv(2, name='conv4_1/expand', in_channels=32, out_channels=192, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_1_expand_bn = self.__batch_normalization(2, 'conv4_1/expand/bn', num_features=192, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_1_dwise = self.__conv(2, name='conv4_1/dwise', in_channels=192, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=192, bias=False)
        self.conv4_1_dwise_bn = self.__batch_normalization(2, 'conv4_1/dwise/bn', num_features=192, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_1_linear = self.__conv(2, name='conv4_1/linear', in_channels=192, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_1_linear_bn = self.__batch_normalization(2, 'conv4_1/linear/bn', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_2_expand = self.__conv(2, name='conv4_2/expand', in_channels=32, out_channels=192, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_2_expand_bn = self.__batch_normalization(2, 'conv4_2/expand/bn', num_features=192, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_2_dwise = self.__conv(2, name='conv4_2/dwise', in_channels=192, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=192, bias=False)
        self.conv4_2_dwise_bn = self.__batch_normalization(2, 'conv4_2/dwise/bn', num_features=192, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_2_linear = self.__conv(2, name='conv4_2/linear', in_channels=192, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_2_linear_bn = self.__batch_normalization(2, 'conv4_2/linear/bn', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_3_expand = self.__conv(2, name='conv4_3/expand', in_channels=32, out_channels=192, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_3_expand_bn = self.__batch_normalization(2, 'conv4_3/expand/bn', num_features=192, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_3_dwise = self.__conv(2, name='conv4_3/dwise', in_channels=192, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=192, bias=False)
        self.conv4_3_dwise_bn = self.__batch_normalization(2, 'conv4_3/dwise/bn', num_features=192, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_3_linear = self.__conv(2, name='conv4_3/linear', in_channels=192, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_3_linear_bn = self.__batch_normalization(2, 'conv4_3/linear/bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_4_expand = self.__conv(2, name='conv4_4/expand', in_channels=64, out_channels=384, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_4_expand_bn = self.__batch_normalization(2, 'conv4_4/expand/bn', num_features=384, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_4_dwise = self.__conv(2, name='conv4_4/dwise', in_channels=384, out_channels=384, kernel_size=(3, 3), stride=(1, 1), groups=384, bias=False)
        self.conv4_4_dwise_bn = self.__batch_normalization(2, 'conv4_4/dwise/bn', num_features=384, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_4_linear = self.__conv(2, name='conv4_4/linear', in_channels=384, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_4_linear_bn = self.__batch_normalization(2, 'conv4_4/linear/bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_5_expand = self.__conv(2, name='conv4_5/expand', in_channels=64, out_channels=384, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_5_expand_bn = self.__batch_normalization(2, 'conv4_5/expand/bn', num_features=384, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_5_dwise = self.__conv(2, name='conv4_5/dwise', in_channels=384, out_channels=384, kernel_size=(3, 3), stride=(1, 1), groups=384, bias=False)
        self.conv4_5_dwise_bn = self.__batch_normalization(2, 'conv4_5/dwise/bn', num_features=384, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_5_linear = self.__conv(2, name='conv4_5/linear', in_channels=384, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_5_linear_bn = self.__batch_normalization(2, 'conv4_5/linear/bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_6_expand = self.__conv(2, name='conv4_6/expand', in_channels=64, out_channels=384, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_6_expand_bn = self.__batch_normalization(2, 'conv4_6/expand/bn', num_features=384, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_6_dwise = self.__conv(2, name='conv4_6/dwise', in_channels=384, out_channels=384, kernel_size=(3, 3), stride=(1, 1), groups=384, bias=False)
        self.conv4_6_dwise_bn = self.__batch_normalization(2, 'conv4_6/dwise/bn', num_features=384, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_6_linear = self.__conv(2, name='conv4_6/linear', in_channels=384, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_6_linear_bn = self.__batch_normalization(2, 'conv4_6/linear/bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_7_expand = self.__conv(2, name='conv4_7/expand', in_channels=64, out_channels=384, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_7_expand_bn = self.__batch_normalization(2, 'conv4_7/expand/bn', num_features=384, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_7_dwise = self.__conv(2, name='conv4_7/dwise', in_channels=384, out_channels=384, kernel_size=(3, 3), stride=(2, 2), groups=384, bias=False)
        self.conv4_7_dwise_bn = self.__batch_normalization(2, 'conv4_7/dwise/bn', num_features=384, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_7_linear = self.__conv(2, name='conv4_7/linear', in_channels=384, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_7_linear_bn = self.__batch_normalization(2, 'conv4_7/linear/bn', num_features=96, eps=9.999999747378752e-06, momentum=0.0)
        self.conv5_1_expand = self.__conv(2, name='conv5_1/expand', in_channels=96, out_channels=576, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv5_1_expand_bn = self.__batch_normalization(2, 'conv5_1/expand/bn', num_features=576, eps=9.999999747378752e-06, momentum=0.0)
        self.conv5_1_dwise = self.__conv(2, name='conv5_1/dwise', in_channels=576, out_channels=576, kernel_size=(3, 3), stride=(1, 1), groups=576, bias=False)
        self.conv5_1_dwise_bn = self.__batch_normalization(2, 'conv5_1/dwise/bn', num_features=576, eps=9.999999747378752e-06, momentum=0.0)
        self.conv5_1_linear = self.__conv(2, name='conv5_1/linear', in_channels=576, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv5_1_linear_bn = self.__batch_normalization(2, 'conv5_1/linear/bn', num_features=96, eps=9.999999747378752e-06, momentum=0.0)
        self.conv5_2_expand = self.__conv(2, name='conv5_2/expand', in_channels=96, out_channels=576, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv5_2_expand_bn = self.__batch_normalization(2, 'conv5_2/expand/bn', num_features=576, eps=9.999999747378752e-06, momentum=0.0)
        self.conv5_2_dwise = self.__conv(2, name='conv5_2/dwise', in_channels=576, out_channels=576, kernel_size=(3, 3), stride=(1, 1), groups=576, bias=False)
        self.conv5_2_dwise_bn = self.__batch_normalization(2, 'conv5_2/dwise/bn', num_features=576, eps=9.999999747378752e-06, momentum=0.0)
        self.conv5_2_linear = self.__conv(2, name='conv5_2/linear', in_channels=576, out_channels=96, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv5_2_linear_bn = self.__batch_normalization(2, 'conv5_2/linear/bn', num_features=96, eps=9.999999747378752e-06, momentum=0.0)
        self.conv5_3_expand = self.__conv(2, name='conv5_3/expand', in_channels=96, out_channels=576, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv5_3_expand_bn = self.__batch_normalization(2, 'conv5_3/expand/bn', num_features=576, eps=9.999999747378752e-06, momentum=0.0)
        self.conv5_3_dwise = self.__conv(2, name='conv5_3/dwise', in_channels=576, out_channels=576, kernel_size=(3, 3), stride=(2, 2), groups=576, bias=False)
        self.conv5_3_dwise_bn = self.__batch_normalization(2, 'conv5_3/dwise/bn', num_features=576, eps=9.999999747378752e-06, momentum=0.0)
        self.conv5_3_linear = self.__conv(2, name='conv5_3/linear', in_channels=576, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv5_3_linear_bn = self.__batch_normalization(2, 'conv5_3/linear/bn', num_features=160, eps=9.999999747378752e-06, momentum=0.0)
        self.conv6_1_expand = self.__conv(2, name='conv6_1/expand', in_channels=160, out_channels=960, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv6_1_expand_bn = self.__batch_normalization(2, 'conv6_1/expand/bn', num_features=960, eps=9.999999747378752e-06, momentum=0.0)
        self.conv6_1_dwise = self.__conv(2, name='conv6_1/dwise', in_channels=960, out_channels=960, kernel_size=(3, 3), stride=(1, 1), groups=960, bias=False)
        self.conv6_1_dwise_bn = self.__batch_normalization(2, 'conv6_1/dwise/bn', num_features=960, eps=9.999999747378752e-06, momentum=0.0)
        self.conv6_1_linear = self.__conv(2, name='conv6_1/linear', in_channels=960, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv6_1_linear_bn = self.__batch_normalization(2, 'conv6_1/linear/bn', num_features=160, eps=9.999999747378752e-06, momentum=0.0)
        self.conv6_2_expand = self.__conv(2, name='conv6_2/expand', in_channels=160, out_channels=960, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv6_2_expand_bn = self.__batch_normalization(2, 'conv6_2/expand/bn', num_features=960, eps=9.999999747378752e-06, momentum=0.0)
        self.conv6_2_dwise = self.__conv(2, name='conv6_2/dwise', in_channels=960, out_channels=960, kernel_size=(3, 3), stride=(1, 1), groups=960, bias=False)
        self.conv6_2_dwise_bn = self.__batch_normalization(2, 'conv6_2/dwise/bn', num_features=960, eps=9.999999747378752e-06, momentum=0.0)
        self.conv6_2_linear = self.__conv(2, name='conv6_2/linear', in_channels=960, out_channels=160, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv6_2_linear_bn = self.__batch_normalization(2, 'conv6_2/linear/bn', num_features=160, eps=9.999999747378752e-06, momentum=0.0)
        self.conv6_3_expand = self.__conv(2, name='conv6_3/expand', in_channels=160, out_channels=960, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv6_3_expand_bn = self.__batch_normalization(2, 'conv6_3/expand/bn', num_features=960, eps=9.999999747378752e-06, momentum=0.0)
        self.conv6_3_dwise = self.__conv(2, name='conv6_3/dwise', in_channels=960, out_channels=960, kernel_size=(3, 3), stride=(1, 1), groups=960, bias=False)
        self.conv6_3_dwise_bn = self.__batch_normalization(2, 'conv6_3/dwise/bn', num_features=960, eps=9.999999747378752e-06, momentum=0.0)
        self.conv6_3_linear = self.__conv(2, name='conv6_3/linear', in_channels=960, out_channels=320, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv6_3_linear_bn = self.__batch_normalization(2, 'conv6_3/linear/bn', num_features=320, eps=9.999999747378752e-06, momentum=0.0)
        # head: 1x1 expansion to 1280 channels, then 1x1 "fc7" classifier
        self.conv6_4 = self.__conv(2, name='conv6_4', in_channels=320, out_channels=1280, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv6_4_bn = self.__batch_normalization(2, 'conv6_4/bn', num_features=1280, eps=9.999999747378752e-06, momentum=0.0)
        self.fc7 = self.__conv(2, name='fc7', in_channels=1280, out_channels=1000, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True)

        # same relu cannot be used two times - define seperate ones
        self.relu1 = torch.nn.ReLU(inplace=True)
        self.relu2_1_expand = torch.nn.ReLU(inplace=True)
        self.relu2_1_dwise = torch.nn.ReLU(inplace=True)
        self.relu2_2_expand = torch.nn.ReLU(inplace=True)
        self.relu2_2_dwise = torch.nn.ReLU(inplace=True)
        self.relu3_1_expand = torch.nn.ReLU(inplace=True)
        self.relu3_1_dwise = torch.nn.ReLU(inplace=True)
        self.relu3_2_expand = torch.nn.ReLU(inplace=True)
        self.relu3_2_dwise = torch.nn.ReLU(inplace=True)
        self.relu4_1_expand = torch.nn.ReLU(inplace=True)
        self.relu4_1_dwise = torch.nn.ReLU(inplace=True)
        self.relu4_2_expand = torch.nn.ReLU(inplace=True)
        self.relu4_2_dwise = torch.nn.ReLU(inplace=True)
        self.relu4_3_expand = torch.nn.ReLU(inplace=True)
        self.relu4_3_dwise = torch.nn.ReLU(inplace=True)
        self.relu4_4_expand = torch.nn.ReLU(inplace=True)
        self.relu4_4_dwise = torch.nn.ReLU(inplace=True)
        self.relu4_5_expand = torch.nn.ReLU(inplace=True)
        self.relu4_5_dwise = torch.nn.ReLU(inplace=True)
        self.relu4_6_expand = torch.nn.ReLU(inplace=True)
        self.relu4_6_dwise = torch.nn.ReLU(inplace=True)
        self.relu4_7_expand = torch.nn.ReLU(inplace=True)
        self.relu4_7_dwise = torch.nn.ReLU(inplace=True)
        self.relu5_1_expand = torch.nn.ReLU(inplace=True)
        self.relu5_1_dwise = torch.nn.ReLU(inplace=True)
        self.relu5_2_expand = torch.nn.ReLU(inplace=True)
        self.relu5_2_dwise = torch.nn.ReLU(inplace=True)
        self.relu5_3_expand = torch.nn.ReLU(inplace=True)
        self.relu5_3_dwise = torch.nn.ReLU(inplace=True)
        self.relu6_1_expand = torch.nn.ReLU(inplace=True)
        self.relu6_1_dwise = torch.nn.ReLU(inplace=True)
        self.relu6_2_expand = torch.nn.ReLU(inplace=True)
        self.relu6_2_dwise = torch.nn.ReLU(inplace=True)
        self.relu6_3_expand = torch.nn.ReLU(inplace=True)
        self.relu6_3_dwise = torch.nn.ReLU(inplace=True)
        self.relu6_4 = torch.nn.ReLU(inplace=True)

        # xnn.layers.AddBlock is used (instead of "+") so the residual adds
        # are explicit modules, e.g. for quantization/graph tools.
        self.add = xnn.layers.AddBlock()
        self.avg = torch.nn.AdaptiveAvgPool2d(1)
        self.flatten = torch.nn.Flatten(start_dim=1)
        
        self._initialize_weights()


    def forward(self, x):
        """Run the network on a (N, 3, H, W) image batch; returns (N, 1000) logits.

        The flat sequence below is the converter-generated dataflow; residual
        ("block_*") adds connect bottlenecks whose input/output channel counts
        and strides match.
        """
        conv1           = self.conv1(x)
        conv1_bn        = self.conv1_bn(conv1)
        relu1           = self.relu1(conv1_bn)
        conv2_1_expand  = self.conv2_1_expand(relu1)
        conv2_1_expand_bn = self.conv2_1_expand_bn(conv2_1_expand)
        relu2_1_expand  = self.relu2_1_expand(conv2_1_expand_bn)
        conv2_1_dwise   = self.conv2_1_dwise(relu2_1_expand)
        conv2_1_dwise_bn = self.conv2_1_dwise_bn(conv2_1_dwise)
        relu2_1_dwise   = self.relu2_1_dwise(conv2_1_dwise_bn)
        conv2_1_linear  = self.conv2_1_linear(relu2_1_dwise)
        conv2_1_linear_bn = self.conv2_1_linear_bn(conv2_1_linear)
        conv2_2_expand  = self.conv2_2_expand(conv2_1_linear_bn)
        conv2_2_expand_bn = self.conv2_2_expand_bn(conv2_2_expand)
        relu2_2_expand  = self.relu2_2_expand(conv2_2_expand_bn)
        conv2_2_dwise   = self.conv2_2_dwise(relu2_2_expand)
        conv2_2_dwise_bn = self.conv2_2_dwise_bn(conv2_2_dwise)
        relu2_2_dwise   = self.relu2_2_dwise(conv2_2_dwise_bn)
        conv2_2_linear  = self.conv2_2_linear(relu2_2_dwise)
        conv2_2_linear_bn = self.conv2_2_linear_bn(conv2_2_linear)
        conv3_1_expand  = self.conv3_1_expand(conv2_2_linear_bn)
        conv3_1_expand_bn = self.conv3_1_expand_bn(conv3_1_expand)
        relu3_1_expand  = self.relu3_1_expand(conv3_1_expand_bn)
        conv3_1_dwise   = self.conv3_1_dwise(relu3_1_expand)
        conv3_1_dwise_bn = self.conv3_1_dwise_bn(conv3_1_dwise)
        relu3_1_dwise   = self.relu3_1_dwise(conv3_1_dwise_bn)
        conv3_1_linear  = self.conv3_1_linear(relu3_1_dwise)
        conv3_1_linear_bn = self.conv3_1_linear_bn(conv3_1_linear)
        # residual connection (same 24-channel shape on both inputs)
        block_3_1       = self.add((conv2_2_linear_bn, conv3_1_linear_bn))
        conv3_2_expand  = self.conv3_2_expand(block_3_1)
        conv3_2_expand_bn = self.conv3_2_expand_bn(conv3_2_expand)
        relu3_2_expand  = self.relu3_2_expand(conv3_2_expand_bn)
        conv3_2_dwise   = self.conv3_2_dwise(relu3_2_expand)
        conv3_2_dwise_bn = self.conv3_2_dwise_bn(conv3_2_dwise)
        relu3_2_dwise   = self.relu3_2_dwise(conv3_2_dwise_bn)
        conv3_2_linear  = self.conv3_2_linear(relu3_2_dwise)
        conv3_2_linear_bn = self.conv3_2_linear_bn(conv3_2_linear)
        conv4_1_expand  = self.conv4_1_expand(conv3_2_linear_bn)
        conv4_1_expand_bn = self.conv4_1_expand_bn(conv4_1_expand)
        relu4_1_expand  = self.relu4_1_expand(conv4_1_expand_bn)
        conv4_1_dwise   = self.conv4_1_dwise(relu4_1_expand)
        conv4_1_dwise_bn = self.conv4_1_dwise_bn(conv4_1_dwise)
        relu4_1_dwise   = self.relu4_1_dwise(conv4_1_dwise_bn)
        conv4_1_linear  = self.conv4_1_linear(relu4_1_dwise)
        conv4_1_linear_bn = self.conv4_1_linear_bn(conv4_1_linear)
        block_4_1       = self.add((conv3_2_linear_bn, conv4_1_linear_bn))
        conv4_2_expand  = self.conv4_2_expand(block_4_1)
        conv4_2_expand_bn = self.conv4_2_expand_bn(conv4_2_expand)
        relu4_2_expand  = self.relu4_2_expand(conv4_2_expand_bn)
        conv4_2_dwise   = self.conv4_2_dwise(relu4_2_expand)
        conv4_2_dwise_bn = self.conv4_2_dwise_bn(conv4_2_dwise)
        relu4_2_dwise   = self.relu4_2_dwise(conv4_2_dwise_bn)
        conv4_2_linear  = self.conv4_2_linear(relu4_2_dwise)
        conv4_2_linear_bn = self.conv4_2_linear_bn(conv4_2_linear)
        block_4_2       = self.add((block_4_1, conv4_2_linear_bn))
        conv4_3_expand  = self.conv4_3_expand(block_4_2)
        conv4_3_expand_bn = self.conv4_3_expand_bn(conv4_3_expand)
        relu4_3_expand  = self.relu4_3_expand(conv4_3_expand_bn)
        conv4_3_dwise   = self.conv4_3_dwise(relu4_3_expand)
        conv4_3_dwise_bn = self.conv4_3_dwise_bn(conv4_3_dwise)
        relu4_3_dwise   = self.relu4_3_dwise(conv4_3_dwise_bn)
        conv4_3_linear  = self.conv4_3_linear(relu4_3_dwise)
        conv4_3_linear_bn = self.conv4_3_linear_bn(conv4_3_linear)
        # no residual here: conv4_3 changes channel count (32 -> 64)
        conv4_4_expand  = self.conv4_4_expand(conv4_3_linear_bn)
        conv4_4_expand_bn = self.conv4_4_expand_bn(conv4_4_expand)
        relu4_4_expand  = self.relu4_4_expand(conv4_4_expand_bn)
        conv4_4_dwise   = self.conv4_4_dwise(relu4_4_expand)
        conv4_4_dwise_bn = self.conv4_4_dwise_bn(conv4_4_dwise)
        relu4_4_dwise   = self.relu4_4_dwise(conv4_4_dwise_bn)
        conv4_4_linear  = self.conv4_4_linear(relu4_4_dwise)
        conv4_4_linear_bn = self.conv4_4_linear_bn(conv4_4_linear)
        block_4_4       = self.add((conv4_3_linear_bn, conv4_4_linear_bn))
        conv4_5_expand  = self.conv4_5_expand(block_4_4)
        conv4_5_expand_bn = self.conv4_5_expand_bn(conv4_5_expand)
        relu4_5_expand  = self.relu4_5_expand(conv4_5_expand_bn)
        conv4_5_dwise   = self.conv4_5_dwise(relu4_5_expand)
        conv4_5_dwise_bn = self.conv4_5_dwise_bn(conv4_5_dwise)
        relu4_5_dwise   = self.relu4_5_dwise(conv4_5_dwise_bn)
        conv4_5_linear  = self.conv4_5_linear(relu4_5_dwise)
        conv4_5_linear_bn = self.conv4_5_linear_bn(conv4_5_linear)
        block_4_5       = self.add((block_4_4, conv4_5_linear_bn))
        conv4_6_expand  = self.conv4_6_expand(block_4_5)
        conv4_6_expand_bn = self.conv4_6_expand_bn(conv4_6_expand)
        relu4_6_expand  = self.relu4_6_expand(conv4_6_expand_bn)
        conv4_6_dwise   = self.conv4_6_dwise(relu4_6_expand)
        conv4_6_dwise_bn = self.conv4_6_dwise_bn(conv4_6_dwise)
        relu4_6_dwise   = self.relu4_6_dwise(conv4_6_dwise_bn)
        conv4_6_linear  = self.conv4_6_linear(relu4_6_dwise)
        conv4_6_linear_bn = self.conv4_6_linear_bn(conv4_6_linear)
        block_4_6       = self.add((block_4_5, conv4_6_linear_bn))
        conv4_7_expand  = self.conv4_7_expand(block_4_6)
        conv4_7_expand_bn = self.conv4_7_expand_bn(conv4_7_expand)
        relu4_7_expand  = self.relu4_7_expand(conv4_7_expand_bn)
        conv4_7_dwise   = self.conv4_7_dwise(relu4_7_expand)
        conv4_7_dwise_bn = self.conv4_7_dwise_bn(conv4_7_dwise)
        relu4_7_dwise   = self.relu4_7_dwise(conv4_7_dwise_bn)
        conv4_7_linear  = self.conv4_7_linear(relu4_7_dwise)
        conv4_7_linear_bn = self.conv4_7_linear_bn(conv4_7_linear)
        conv5_1_expand  = self.conv5_1_expand(conv4_7_linear_bn)
        conv5_1_expand_bn = self.conv5_1_expand_bn(conv5_1_expand)
        relu5_1_expand  = self.relu5_1_expand(conv5_1_expand_bn)
        conv5_1_dwise   = self.conv5_1_dwise(relu5_1_expand)
        conv5_1_dwise_bn = self.conv5_1_dwise_bn(conv5_1_dwise)
        relu5_1_dwise   = self.relu5_1_dwise(conv5_1_dwise_bn)
        conv5_1_linear  = self.conv5_1_linear(relu5_1_dwise)
        conv5_1_linear_bn = self.conv5_1_linear_bn(conv5_1_linear)
        block_5_1       = self.add((conv4_7_linear_bn, conv5_1_linear_bn))
        conv5_2_expand  = self.conv5_2_expand(block_5_1)
        conv5_2_expand_bn = self.conv5_2_expand_bn(conv5_2_expand)
        relu5_2_expand  = self.relu5_2_expand(conv5_2_expand_bn)
        conv5_2_dwise   = self.conv5_2_dwise(relu5_2_expand)
        conv5_2_dwise_bn = self.conv5_2_dwise_bn(conv5_2_dwise)
        relu5_2_dwise   = self.relu5_2_dwise(conv5_2_dwise_bn)
        conv5_2_linear  = self.conv5_2_linear(relu5_2_dwise)
        conv5_2_linear_bn = self.conv5_2_linear_bn(conv5_2_linear)
        block_5_2       = self.add((block_5_1, conv5_2_linear_bn))
        conv5_3_expand  = self.conv5_3_expand(block_5_2)
        conv5_3_expand_bn = self.conv5_3_expand_bn(conv5_3_expand)
        relu5_3_expand  = self.relu5_3_expand(conv5_3_expand_bn)
        conv5_3_dwise   = self.conv5_3_dwise(relu5_3_expand)
        conv5_3_dwise_bn = self.conv5_3_dwise_bn(conv5_3_dwise)
        relu5_3_dwise   = self.relu5_3_dwise(conv5_3_dwise_bn)
        conv5_3_linear  = self.conv5_3_linear(relu5_3_dwise)
        conv5_3_linear_bn = self.conv5_3_linear_bn(conv5_3_linear)
        conv6_1_expand  = self.conv6_1_expand(conv5_3_linear_bn)
        conv6_1_expand_bn = self.conv6_1_expand_bn(conv6_1_expand)
        relu6_1_expand  = self.relu6_1_expand(conv6_1_expand_bn)
        conv6_1_dwise   = self.conv6_1_dwise(relu6_1_expand)
        conv6_1_dwise_bn = self.conv6_1_dwise_bn(conv6_1_dwise)
        relu6_1_dwise   = self.relu6_1_dwise(conv6_1_dwise_bn)
        conv6_1_linear  = self.conv6_1_linear(relu6_1_dwise)
        conv6_1_linear_bn = self.conv6_1_linear_bn(conv6_1_linear)
        block_6_1       = self.add((conv5_3_linear_bn, conv6_1_linear_bn))
        conv6_2_expand  = self.conv6_2_expand(block_6_1)
        conv6_2_expand_bn = self.conv6_2_expand_bn(conv6_2_expand)
        relu6_2_expand  = self.relu6_2_expand(conv6_2_expand_bn)
        conv6_2_dwise   = self.conv6_2_dwise(relu6_2_expand)
        conv6_2_dwise_bn = self.conv6_2_dwise_bn(conv6_2_dwise)
        relu6_2_dwise   = self.relu6_2_dwise(conv6_2_dwise_bn)
        conv6_2_linear  = self.conv6_2_linear(relu6_2_dwise)
        conv6_2_linear_bn = self.conv6_2_linear_bn(conv6_2_linear)
        block_6_2       = self.add((block_6_1, conv6_2_linear_bn))
        conv6_3_expand  = self.conv6_3_expand(block_6_2)
        conv6_3_expand_bn = self.conv6_3_expand_bn(conv6_3_expand)
        relu6_3_expand  = self.relu6_3_expand(conv6_3_expand_bn)
        conv6_3_dwise   = self.conv6_3_dwise(relu6_3_expand)
        conv6_3_dwise_bn = self.conv6_3_dwise_bn(conv6_3_dwise)
        relu6_3_dwise   = self.relu6_3_dwise(conv6_3_dwise_bn)
        conv6_3_linear  = self.conv6_3_linear(relu6_3_dwise)
        conv6_3_linear_bn = self.conv6_3_linear_bn(conv6_3_linear)
        conv6_4         = self.conv6_4(conv6_3_linear_bn)
        conv6_4_bn      = self.conv6_4_bn(conv6_4)
        relu6_4         = self.relu6_4(conv6_4_bn)
        # global average pool -> 1x1 conv classifier -> flatten to (N, 1000)
        pool6           = self.avg(relu6_4)
        fc7             = self.fc7(pool6)
        out             =  self.flatten(fc7)
        return out


    @staticmethod
    def __batch_normalization(dim, name, **kwargs):
        """Create a BatchNorm layer of the given dimensionality.

        `name` is the original Caffe layer name; it is kept for traceability
        but not used by the layer itself.
        """
        if   dim == 1:  layer = torch.nn.BatchNorm1d(**kwargs)
        elif dim == 2:  layer = torch.nn.BatchNorm2d(**kwargs)
        elif dim == 3:  layer = torch.nn.BatchNorm3d(**kwargs)
        else:           raise NotImplementedError()
        return layer

    @staticmethod
    def __conv(dim, name, **kwargs):
        """Create a Conv layer of the given dimensionality with 'same'-style
        default padding (kernel_size // 2) when no padding is specified.

        `name` is the original Caffe layer name; kept for traceability only.
        """
        if 'padding' not in kwargs.keys():
            kernel_size = kwargs['kernel_size']
            padding = (kernel_size[0]//2,kernel_size[1]//2) if isinstance(kernel_size,(list,tuple)) else kernel_size//2
            kwargs['padding'] = padding
        #
        if   dim == 1:  layer = torch.nn.Conv1d(**kwargs)
        elif dim == 2:  layer = torch.nn.Conv2d(**kwargs)
        elif dim == 3:  layer = torch.nn.Conv3d(**kwargs)
        else:           raise NotImplementedError()
        return layer

    def _initialize_weights(self):
        """Initialize weights: kaiming-normal for convs, (1, 0) for batch-norm
        (weight, bias), and N(0, 0.01) for any linear layers.

        These values are overwritten when pretrained weights are loaded.
        """
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, torch.nn.BatchNorm2d):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

class mobilenetv2_shicai(MobileNetV2Shicai):
    """Convenience wrapper: merges the user config into the defaults and
    optionally loads pretrained weights.

    Note: lowercase class name is intentional — it acts as a factory-style
    model constructor, matching the naming of the other model entry points.
    """
    def __init__(self, model_config=None, pretrained=None, **kwargs):
        merged_config = get_config().merge_from(model_config)
        super().__init__(merged_config)
        # truthiness check: skip loading for None/empty pretrained spec
        if pretrained:
            load_weights(self, pretrained)


